{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "3XX46cTrh6iD"
      },
      "source": [
        "##### Copyright 2021 The TensorFlow Authors."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "sKrlWr6Kh-mF"
      },
      "outputs": [],
      "source": [
        "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "# https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "hST65kOHXpiL"
      },
      "source": [
        "# Transfer Learning for the Audio Domain with TensorFlow Lite Model Maker\n",
        "\n",
        "\u003ctable class=\"tfo-notebook-buttons\" align=\"left\"\u003e\n",
        "  \u003ctd\u003e\n",
        "    \u003ca target=\"_blank\" href=\"https://www.tensorflow.org/lite/tutorials/model_maker_audio_classification\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/tf_logo_32px.png\" /\u003eView on TensorFlow.org\u003c/a\u003e\n",
        "  \u003c/td\u003e\n",
        "  \u003ctd\u003e\n",
        "    \u003ca target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/lite/g3doc/tutorials/model_maker_audio_classification.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" /\u003eRun in Google Colab\u003c/a\u003e\n",
        "  \u003c/td\u003e\n",
        "  \u003ctd\u003e\n",
        "    \u003ca target=\"_blank\" href=\"https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/g3doc/tutorials/model_maker_audio_classification.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" /\u003eView source on GitHub\u003c/a\u003e\n",
        "  \u003c/td\u003e\n",
        "  \u003ctd\u003e\n",
        "    \u003ca href=\"https://storage.googleapis.com/tensorflow_docs/tensorflow/tensorflow/lite/g3doc/tutorials/model_maker_audio_classification.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/download_logo_32px.png\" /\u003eDownload notebook\u003c/a\u003e\n",
        "  \u003c/td\u003e\n",
        "  \u003ctd\u003e\n",
        "    \u003ca href=\"https://tfhub.dev/google/yamnet/1\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/hub_logo_32px.png\" /\u003eSee TF Hub model\u003c/a\u003e\n",
        "  \u003c/td\u003e\n",
        "\n",
        "\u003c/table\u003e"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "BB5k6xNKJ5Xe"
      },
      "source": [
        "\n",
        "In this colab notebook, you'll learn how to use the [TensorFlow Lite Model Maker](https://www.tensorflow.org/lite/guide/model_maker) to train a custom audio classification model.\n",
        "\n",
        "The Model Maker library uses transfer learning to simplify the process of training a TensorFlow Lite model using a custom dataset. Retraining a TensorFlow Lite model with your own custom dataset reduces the amount of training data and time required.\n",
        "\n",
        "It is part of the [Codelab to Customize an Audio model and deploy on Android](https://codelabs.developers.google.com/codelabs/tflite-audio-classification-custom-model-android).\n",
        "\n",
        "You'll use a custom birds dataset and export a TFLite model that can be used on a phone, a TensorFlow.JS model that can be used for inference in the browser and also a SavedModel version that you can use for serving.\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "AeZZ_cSsZfPx"
      },
      "source": [
        "## Intalling dependencies\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "wbMc4vHjaYdQ"
      },
      "outputs": [],
      "source": [
        "! pip install tflite-model-maker"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "z2ck_Ghdcgt9"
      },
      "source": [
        "## Import TensorFlow, Model Maker and other libraries\n",
        "\n",
        "Among the dependencies that are needed, you'll use TensorFlow and Model Maker. Aside those, the others are for audio manipulation, playing and visualizations."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "rwUA9u4oWoCR"
      },
      "outputs": [],
      "source": [
        "import tensorflow as tf\n",
        "import tflite_model_maker as mm\n",
        "from tflite_model_maker import audio_classifier\n",
        "import os\n",
        "\n",
        "import numpy as np\n",
        "import matplotlib.pyplot as plt\n",
        "import seaborn as sns\n",
        "\n",
        "import itertools\n",
        "import glob\n",
        "import random\n",
        "\n",
        "from IPython.display import Audio, Image\n",
        "from scipy.io import wavfile\n",
        "\n",
        "print(f\"TensorFlow Version: {tf.__version__}\")\n",
        "print(f\"Model Maker Version: {mm.__version__}\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "HIfm2TxKZAuA"
      },
      "source": [
        "## The Birds dataset\n",
        "\n",
        "The Birds dataset is an education collection of 5 types of birds songs:\n",
        "\n",
        "- White-breasted Wood-Wren\n",
        "- House Sparrow\n",
        "- Red Crossbill\n",
        "- Chestnut-crowned Antpitta\n",
        "- Azara's Spinetail\n",
        "\n",
        "The original audio came from [Xeno-canto](https://www.xeno-canto.org/) which is a website dedicated to sharing bird sounds from all over the world.\n",
        "\n",
        "Let's start by downloading the data."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "upNRfilkNSmr"
      },
      "outputs": [],
      "source": [
        "birds_dataset_folder = tf.keras.utils.get_file('birds_dataset.zip',\n",
        "                                                'https://storage.googleapis.com/laurencemoroney-blog.appspot.com/birds_dataset.zip',\n",
        "                                                cache_dir='./',\n",
        "                                                cache_subdir='dataset',\n",
        "                                                extract=True)\n",
        "                                                "
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "441bbzZ5d6oq"
      },
      "source": [
        "## Explore the data\n",
        "\n",
        "The audios are already split in train and test folders. Inside each split folder, there's one folder for each bird, using their `bird_code` as name.\n",
        "\n",
        "The audios are all mono and with 16kHz sample rate.\n",
        "\n",
        "For more information about each file, you can read the `metadata.csv` file. It contains all the files authors, lincenses and some more information. You won't need to read it yourself on this tutorial."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ayd7UqCfQQFU"
      },
      "outputs": [],
      "source": [
        "# @title [Run this] Util functions and data structures.\n",
        "\n",
        "data_dir = './dataset/small_birds_dataset'\n",
        "\n",
        "bird_code_to_name = {\n",
        "  'wbwwre1': 'White-breasted Wood-Wren',\n",
        "  'houspa': 'House Sparrow',\n",
        "  'redcro': 'Red Crossbill',  \n",
        "  'chcant2': 'Chestnut-crowned Antpitta',\n",
        "  'azaspi1': \"Azara's Spinetail\",   \n",
        "}\n",
        "\n",
        "birds_images = {\n",
        "  'wbwwre1': 'https://upload.wikimedia.org/wikipedia/commons/thumb/2/22/Henicorhina_leucosticta_%28Cucarachero_pechiblanco%29_-_Juvenil_%2814037225664%29.jpg/640px-Henicorhina_leucosticta_%28Cucarachero_pechiblanco%29_-_Juvenil_%2814037225664%29.jpg', # \tAlejandro Bayer Tamayo from Armenia, Colombia \n",
        "  'houspa': 'https://upload.wikimedia.org/wikipedia/commons/thumb/5/52/House_Sparrow%2C_England_-_May_09.jpg/571px-House_Sparrow%2C_England_-_May_09.jpg', # \tDiliff\n",
        "  'redcro': 'https://upload.wikimedia.org/wikipedia/commons/thumb/4/49/Red_Crossbills_%28Male%29.jpg/640px-Red_Crossbills_%28Male%29.jpg', #  Elaine R. Wilson, www.naturespicsonline.com\n",
        "  'chcant2': 'https://upload.wikimedia.org/wikipedia/commons/thumb/6/67/Chestnut-crowned_antpitta_%2846933264335%29.jpg/640px-Chestnut-crowned_antpitta_%2846933264335%29.jpg', # \tMike's Birds from Riverside, CA, US\n",
        "  'azaspi1': 'https://upload.wikimedia.org/wikipedia/commons/thumb/b/b2/Synallaxis_azarae_76608368.jpg/640px-Synallaxis_azarae_76608368.jpg', # https://www.inaturalist.org/photos/76608368\n",
        "}\n",
        "\n",
        "test_files = os.path.abspath(os.path.join(data_dir, 'test/*/*.wav'))\n",
        "\n",
        "def get_random_audio_file():\n",
        "  test_list = glob.glob(test_files)\n",
        "  random_audio_path = random.choice(test_list)\n",
        "  return random_audio_path\n",
        "\n",
        "\n",
        "def show_bird_data(audio_path):\n",
        "  sample_rate, audio_data = wavfile.read(audio_path, 'rb')\n",
        "\n",
        "  bird_code = audio_path.split('/')[-2]\n",
        "  print(f'Bird name: {bird_code_to_name[bird_code]}')\n",
        "  print(f'Bird code: {bird_code}')\n",
        "  display(Image(birds_images[bird_code]))\n",
        "\n",
        "  plttitle = f'{bird_code_to_name[bird_code]} ({bird_code})'\n",
        "  plt.title(plttitle)\n",
        "  plt.plot(audio_data)\n",
        "  display(Audio(audio_data, rate=sample_rate))\n",
        "\n",
        "print('functions and data structures created')"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Yrv0uD7aXYl4"
      },
      "source": [
        "### Playing some audio\n",
        "\n",
        "To have a better understanding about the data, lets listen to a random audio files from the test split.\n",
        "\n",
        "Note: later in this notebook you'll run inference on this audio for testing"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "tEeMZh-VQy97"
      },
      "outputs": [],
      "source": [
        "random_audio = get_random_audio_file()\n",
        "show_bird_data(random_audio)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "pQj1Mf7YZELS"
      },
      "source": [
        "## Training the Model\n",
        "\n",
        "When using Model Maker for audio, you have to start with a model spec. This is the base model that your new model will extract information to learn about the new classes. It also affects how the dataset will be transformed to respect the models spec parameters like: sample rate, number of channels.\n",
        "\n",
        "[YAMNet](https://tfhub.dev/google/yamnet/1) is an audio event classifier trained on the AudioSet dataset to predict audio events from the AudioSet ontology.\n",
        "\n",
        "It's input is expected to be at 16kHz and with 1 channel.\n",
        "\n",
        "You don't need to do any resampling yourself. Model Maker takes care of that for you.\n",
        "\n",
        "- `frame_length` is to decide how long each traininng sample is. in this caase EXPECTED_WAVEFORM_LENGTH * 3s\n",
        "\n",
        "- `frame_steps` is to decide how far appart are the training samples. In this case, the ith sample will start at EXPECTED_WAVEFORM_LENGTH * 6s after the (i-1)th sample.\n",
        "\n",
        "The reason to set these values is to work around some limitation in real world dataset.\n",
        "\n",
        "For example, in the bird dataset, birds don't sing all the time. They sing, rest and sing again, with noises in between. Having a long frame would help capture the singing, but setting it too long will reduce the number of samples for training.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "tUcxtfHXY7XS"
      },
      "outputs": [],
      "source": [
        "spec = audio_classifier.YamNetSpec(\n",
        "    keep_yamnet_and_custom_heads=True,\n",
        "    frame_step=3 * audio_classifier.YamNetSpec.EXPECTED_WAVEFORM_LENGTH,\n",
        "    frame_length=6 * audio_classifier.YamNetSpec.EXPECTED_WAVEFORM_LENGTH)"
      ]
    },
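    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To make the effect of `frame_length` and `frame_step` concrete, here is a minimal sketch that counts how many training windows one clip yields. It is not Model Maker's implementation: the helper `count_windows` is hypothetical, the value 15600 for `EXPECTED_WAVEFORM_LENGTH` and the 30-second clip are assumptions, and it assumes the lengths are sample counts and that only full windows are kept.\n",
        "\n",
        "```python\n",
        "# Assumption: YAMNet expects 15600 samples (0.975 s at 16 kHz) per waveform.\n",
        "# Check audio_classifier.YamNetSpec.EXPECTED_WAVEFORM_LENGTH for the real value.\n",
        "EXPECTED_WAVEFORM_LENGTH = 15600\n",
        "SAMPLE_RATE = 16000\n",
        "\n",
        "\n",
        "def count_windows(num_samples, frame_length, frame_step):\n",
        "    '''Count full windows of frame_length taken every frame_step samples.'''\n",
        "    # Floor division goes negative when the clip is shorter than one window,\n",
        "    # so clamp the result at zero.\n",
        "    return max(0, 1 + (num_samples - frame_length) // frame_step)\n",
        "\n",
        "\n",
        "frame_length = 6 * EXPECTED_WAVEFORM_LENGTH\n",
        "frame_step = 3 * EXPECTED_WAVEFORM_LENGTH\n",
        "\n",
        "clip = 30 * SAMPLE_RATE  # a hypothetical 30-second clip\n",
        "print(count_windows(clip, frame_length, frame_step))\n",
        "print(count_windows(clip, 2 * frame_length, frame_step))\n",
        "```\n",
        "\n",
        "Under these assumptions the clip yields 9 windows, and doubling `frame_length` drops that to 7, which is the trade-off between capturing a full song and keeping enough training samples."
      ]
    },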
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "EF185yZ_M7zu"
      },
      "source": [
        "## Loading the data\n",
        "\n",
        "Model Maker has the API to load the data from a folder and have it in the expected format for the model spec.\n",
        "\n",
        "The train and test split are based on the folders. The validation dataset will be created as 20% of the train split.\n",
        "\n",
        "Note: The `cache=True` is important to make training later faster but it will also require more RAM to hold the data. For the birds dataset that is not a problem since it's only 300MB, but if you use your own data you have to pay attention to it.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "cX0RqETqZgzo"
      },
      "outputs": [],
      "source": [
        "train_data = audio_classifier.DataLoader.from_folder(\n",
        "    spec, os.path.join(data_dir, 'train'), cache=True)\n",
        "train_data, validation_data = train_data.split(0.8)\n",
        "test_data = audio_classifier.DataLoader.from_folder(\n",
        "    spec, os.path.join(data_dir, 'test'), cache=True)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ziMghju-Rts2"
      },
      "source": [
        "## Training the model\n",
        "\n",
        "the audio_classifier has the [`create`](https://www.tensorflow.org/lite/api_docs/python/tflite_model_maker/audio_classifier/create) method that creates a model and already start training it. \n",
        "\n",
        "You can customize many parameterss, for more information you can read more details in the documentation.\n",
        "\n",
        "On this first try you'll use all the default configurations and train for 100 epochs.\n",
        "\n",
        "Note: The first epoch takes longer than all the other ones because it's when the cache is created. After that each epoch takes close to 1 second."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "8r6Awvl4ZkIv"
      },
      "outputs": [],
      "source": [
        "batch_size = 128\n",
        "epochs = 100\n",
        "\n",
        "print('Training the model')\n",
        "model = audio_classifier.create(\n",
        "    train_data,\n",
        "    spec,\n",
        "    validation_data,\n",
        "    batch_size=batch_size,\n",
        "    epochs=epochs)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "oXMEHZkAxJTl"
      },
      "source": [
        "The accuracy looks good but it's important to run the evaluation step on the test data and vefify your model achieved good results on unseed data."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "xEM1aRNKtvHS"
      },
      "source": [
        ""
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "GDoQACMrZnOx"
      },
      "outputs": [],
      "source": [
        "print('Evaluating the model')\n",
        "model.evaluate(test_data)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "8QRRAM39aOxS"
      },
      "source": [
        "## Understanding your model\n",
        "\n",
        "When training a classifier, it's useful to see the [confusion matrix](https://en.wikipedia.org/wiki/Confusion_matrix). The confusion matrix gives you detailed knowledge of how your classifier is performing on test data.\n",
        "\n",
        "Model Maker already creates the confusion matrix for you."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "zqB3c0368iH3"
      },
      "outputs": [],
      "source": [
        "def show_confusion_matrix(confusion, test_labels):\n",
        "  \"\"\"Compute confusion matrix and normalize.\"\"\"\n",
        "  confusion_normalized = confusion.astype(\"float\") / confusion.sum(axis=1)\n",
        "  axis_labels = test_labels\n",
        "  ax = sns.heatmap(\n",
        "      confusion_normalized, xticklabels=axis_labels, yticklabels=axis_labels,\n",
        "      cmap='Blues', annot=True, fmt='.2f', square=True)\n",
        "  plt.title(\"Confusion matrix\")\n",
        "  plt.ylabel(\"True label\")\n",
        "  plt.xlabel(\"Predicted label\")\n",
        "\n",
        "confusion_matrix = model.confusion_matrix(test_data)\n",
        "show_confusion_matrix(confusion_matrix.numpy(), test_data.index_to_label)"
      ]
    },
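    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Normalizing a confusion matrix by row means dividing each row by its own total, so each row reads as the distribution of predictions for one true class. The sketch below shows this with plain NumPy on made-up counts; the matrix and variable names are illustrative and do not come from the trained model.\n",
        "\n",
        "```python\n",
        "import numpy as np\n",
        "\n",
        "# Made-up counts for 3 classes: rows are true labels, columns are predictions.\n",
        "confusion = np.array([[8, 1, 1],\n",
        "                      [2, 6, 2],\n",
        "                      [0, 5, 15]])\n",
        "\n",
        "# keepdims=True keeps the row sums as a column of shape (3, 1),\n",
        "# so broadcasting divides each row by its own total.\n",
        "row_sums = confusion.sum(axis=1, keepdims=True)\n",
        "normalized = confusion / row_sums\n",
        "\n",
        "print(normalized.sum(axis=1))  # every row sums to 1\n",
        "print(np.diag(normalized))     # per-class recall on the diagonal\n",
        "```\n",
        "\n",
        "Without `keepdims=True`, the row sums broadcast along the columns instead, and the rows no longer sum to 1 unless every class has the same number of examples."
      ]
    },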
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "7gr1s7juBy7H"
      },
      "source": [
        "## Testing the model [Optional]\n",
        "\n",
        "You can try the model on a sample audio from the test dataset just to see the results.\n",
        "\n",
        "First you get the serving model."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "PmlmTl42Bq_u"
      },
      "outputs": [],
      "source": [
        "serving_model = model.create_serving_model()\n",
        "\n",
        "print(f'Model\\'s input shape and type: {serving_model.inputs}')\n",
        "print(f'Model\\'s output shape and type: {serving_model.outputs}')"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "eQsZFO2mrYhx"
      },
      "source": [
        "Coming back to the random audio you loaded earlier"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "8dv5ViK0reXc"
      },
      "outputs": [],
      "source": [
        "# if you want to try another file just uncoment the line below\n",
        "random_audio = get_random_audio_file()\n",
        "show_bird_data(random_audio)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "uixOfKSUj_9m"
      },
      "source": [
        "The model created has a fixed input window. \n",
        "\n",
        "For a given audio file, you'll have to split it in windows of data of the expected size. The last window might need to be filled with zeros."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "YAvGKQL0lNty"
      },
      "outputs": [],
      "source": [
        "sample_rate, audio_data = wavfile.read(random_audio, 'rb')\n",
        "\n",
        "audio_data = np.array(audio_data) / tf.int16.max\n",
        "input_size = serving_model.input_shape[1]\n",
        "\n",
        "splitted_audio_data = tf.signal.frame(audio_data, input_size, input_size, pad_end=True, pad_value=0)\n",
        "\n",
        "print(f'Test audio path: {random_audio}')\n",
        "print(f'Original size of the audio data: {len(audio_data)}')\n",
        "print(f'Number of windows for inference: {len(splitted_audio_data)}')"
      ]
    },
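    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The windowing step can also be reproduced with plain NumPy, which may help when TensorFlow is not available at inference time. This is a sketch in the spirit of the framing above, not a drop-in replacement: `split_into_windows` is a hypothetical helper, and the 10-sample input is made up.\n",
        "\n",
        "```python\n",
        "import numpy as np\n",
        "\n",
        "\n",
        "def split_into_windows(audio, window_size):\n",
        "    '''Split a 1-D array into non-overlapping windows, zero-padding the last.'''\n",
        "    num_windows = -(-len(audio) // window_size)  # ceiling division\n",
        "    padded = np.zeros(num_windows * window_size, dtype=audio.dtype)\n",
        "    padded[:len(audio)] = audio\n",
        "    return padded.reshape(num_windows, window_size)\n",
        "\n",
        "\n",
        "audio = np.arange(1, 11, dtype=np.float32)  # 10 made-up samples\n",
        "windows = split_into_windows(audio, 4)\n",
        "print(windows.shape)\n",
        "print(windows[-1])\n",
        "```\n",
        "\n",
        "Here 10 samples with a window of 4 give 3 windows, and the last one holds the 2 remaining samples followed by 2 zeros."
      ]
    },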
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "PLxKd0eFkMcR"
      },
      "source": [
        "You'll loop over all the splitted audio and apply the model for each one of them.\n",
        "\n",
        "The model you've just trained has 2 outputs: The original YAMNet's output and the one you've just trained. This is important because the real world environment is more complicated than just bird sounds. You can use the YAMNet's output to filter out non relevant audio, for example, on the birds use case, if YAMNet is not classifying Birds or Animals, this might show that the output from your model might have an irrelevant classification.\n",
        "\n",
        "\n",
        "Below both outpus are printed to make it easier to understand their relation. Most of the mistakes that your model make are when YAMNet's prediction is not related to your domain (eg: birds)."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "4-8fJLrxGwYT"
      },
      "outputs": [],
      "source": [
        "print(random_audio)\n",
        "\n",
        "results = []\n",
        "print('Result of the window ith:  your model class -\u003e score,  (spec class -\u003e score)')\n",
        "for i, data in enumerate(splitted_audio_data):\n",
        "  yamnet_output, inference = serving_model(data)\n",
        "  results.append(inference[0].numpy())\n",
        "  result_index = tf.argmax(inference[0])\n",
        "  spec_result_index = tf.argmax(yamnet_output[0])\n",
        "  t = spec._yamnet_labels()[spec_result_index]\n",
        "  result_str = f'Result of the window {i}: ' \\\n",
        "  f'\\t{test_data.index_to_label[result_index]} -\u003e {inference[0][result_index].numpy():.3f}, ' \\\n",
        "  f'\\t({spec._yamnet_labels()[spec_result_index]} -\u003e {yamnet_output[0][spec_result_index]:.3f})'\n",
        "  print(result_str)\n",
        "\n",
        "\n",
        "results_np = np.array(results)\n",
        "mean_results = results_np.mean(axis=0)\n",
        "result_index = mean_results.argmax()\n",
        "print(f'Mean result: {test_data.index_to_label[result_index]} -\u003e {mean_results[result_index]}')"
      ]
    },
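    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "One way to act on the two outputs is to average your model's scores only over the windows where YAMNet's top class is relevant to the domain. The sketch below illustrates that idea on made-up numbers; the scores, the per-window YAMNet labels, and the `relevant` set are all assumptions, and the fallback to all windows is just one possible design choice.\n",
        "\n",
        "```python\n",
        "import numpy as np\n",
        "\n",
        "# Made-up scores from the custom head for 4 windows and 3 bird classes.\n",
        "bird_labels = ['azaspi1', 'chcant2', 'houspa']\n",
        "custom_scores = np.array([[0.7, 0.2, 0.1],\n",
        "                          [0.1, 0.8, 0.1],\n",
        "                          [0.6, 0.3, 0.1],\n",
        "                          [0.8, 0.1, 0.1]])\n",
        "\n",
        "# Hypothetical top YAMNet class for each window.\n",
        "yamnet_top_labels = ['Bird', 'Speech', 'Animal', 'Bird']\n",
        "relevant = {'Bird', 'Animal'}  # assumed set of domain-relevant classes\n",
        "\n",
        "# Keep only the windows whose top YAMNet class is relevant.\n",
        "keep = np.array([label in relevant for label in yamnet_top_labels])\n",
        "\n",
        "# Fall back to all windows if none pass the filter.\n",
        "selected = custom_scores[keep] if keep.any() else custom_scores\n",
        "mean_scores = selected.mean(axis=0)\n",
        "\n",
        "best = int(mean_scores.argmax())\n",
        "print(bird_labels[best], mean_scores[best])\n",
        "```\n",
        "\n",
        "In this example the second window, which YAMNet labels as speech, is ignored, so its confident but likely irrelevant bird prediction does not affect the mean."
      ]
    },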
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "yASrikBgZ9ZO"
      },
      "source": [
        "## Exporting the model\n",
        "\n",
        "The last step is exporting your model to be used on embedded devices or on the browser.\n",
        "\n",
        "The `export` method export both formats for you."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Xw_ehPxAdQlz"
      },
      "outputs": [],
      "source": [
        "models_path = './birds_models'\n",
        "print(f'Exporing the TFLite model to {models_path}')\n",
        "\n",
        "model.export(models_path, tflite_filename='my_birds_model.tflite')"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "jjZRKmurA3y_"
      },
      "source": [
        "You can also export the SavedModel version for serving or using on a Python environment."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "veBwppOsA-kn"
      },
      "outputs": [],
      "source": [
        "model.export(models_path, export_format=[mm.ExportFormat.SAVED_MODEL, mm.ExportFormat.LABEL])"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "5xr0idac6xfi"
      },
      "source": [
        "## Next Steps\n",
        "\n",
        "You did it.\n",
        "\n",
        "Now your new model can be deployed on  mobile devices using [TFLite AudioClassifier Task API](https://www.tensorflow.org/lite/inference_with_metadata/task_library/audio_classifier).\n",
        "\n",
        "You can also try the same process with your own data with different classes and here is the documentation for [Model Maker for Audio Classification](https://www.tensorflow.org/lite/api_docs/python/tflite_model_maker/audio_classifier).\n",
        "\n",
        "Also learn from end-to-end reference apps: [Android](https://github.com/tensorflow/examples/tree/master/lite/examples/sound_classification/android/), [iOS](https://github.com/tensorflow/examples/tree/master/lite/examples/sound_classification/ios)."
      ]
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "collapsed_sections": [],
      "name": "model_maker_audio_classification.ipynb",
      "private_outputs": true,
      "provenance": [],
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
